{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Training Engine\n",
    "This notebook will show you how to use the lightnet engine and related classes to help you setup your training more easily.\n",
    "\n",
    "When training models with pytorch, there is quite a bit of boilerplate code that needs to be written each time.\n",
    "The [lightnet.engine](../api/engine.rst) submodule provides functionality to make it easier to setup your training pipelines and allows to easily perform the same training routine on different networks (with different losses, hyperparameters, etc).\n",
    "\n",
    "In this tutorial, we will discuss the 2 most important classes of this module:\n",
    "\n",
    "- [HyperParameters](#HyperParameters): This class allows to group together all hyperparameters of a certain model, making it easier to perform the same training on various different models\n",
    "- [Engine](#Engine): This class provides basic functionality to reduce the boilerplate code needed to train a model. It has some opinionated features that make it easy to split up training in batches and even mini-batches.\n",
    "\n",
    "Besides these 2 classes, this submodule contains a [LinePlotter](../api/engine.rst#lightnet.engine.LinePlotter) to easily plot data with visdom and a [SchedulerCompositor](../api/engine.rst#lightnet.engine.SchedulerCompositor) that allows to use multiple schedulers throughout your training pipeline."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Basic imports\n",
    "import lightnet as ln\n",
    "import torch\n",
    "import numpy as np\n",
    "from PIL import Image\n",
    "import matplotlib.pyplot as plt\n",
    "import brambox as bb\n",
    "\n",
    "# Settings\n",
    "ln.logger.setConsoleLevel('ERROR')             # Only show error log messages\n",
    "bb.logger.setConsoleLevel('ERROR')             # Only show error log messages"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## HyperParameters\n",
    "The [HyperParameters](../api/engine.rst#lightnet.engine.HyperParameters) class is used to store all kinds of different data, and allows to (de)serialize that data to a file for later use.\n",
    "\n",
    "To start using it, you can simply initialize an object, and give it any keyword argument you want to save.\n",
    "Once it is created, you can add more parameters by creating new attributes on the fly (syntax: `object.parameter = value`).  \n",
    "Any attributes that start with an underscore will be marked, so that they are not saved when serializing the data in this class.\n",
    "They will then be stored without as an attribute without the underscore.\n",
    "\n",
    "When saving this class, it will loop through all of its attributes (except those that are marked) and check whether that value has a `state_dict()` function.\n",
    "If it exists, it will store the result of that function instead of the value itself.  \n",
    "The same happens when loading data, except with the `load_state_dict()` function.\n",
    "\n",
    "<div class=\"alert alert-info\">\n",
    "\n",
    "**Note:**\n",
    "\n",
    "You will see that when this object is created, it contains two extra attributes (_batch_, _epoch_).\n",
    "These values get used by the Engine (see next section), to keep track of where the training is, and will always be initialized to **0** when creating a HyperParameters object.\n",
    "\n",
    "</div>\n",
    "\n",
    "<div class=\"alert alert-info\">\n",
    "\n",
    "**Note:**\n",
    "\n",
    "We store the data with a _.state.pt_ extension.\n",
    "This function internally uses the _torch.save()_ function, justifying the final _.pt_ extension, but the _.state_ part is just a convention we are using to differentiate between _HyperParameter.save()_ and _network.save()_\n",
    "\n",
    "</div>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "HyperParameters(\n",
      "  bar = ['abc', 'def', 'ghi']\n",
      "  batch = 0\n",
      "  baz* = 6.283\n",
      "  epoch = 0\n",
      "  foo* = 10\n",
      "  network = YoloV2\n",
      ")\n"
     ]
    }
   ],
   "source": [
    "# Create parameters object\n",
    "params = ln.engine.HyperParameters(\n",
    "    network = ln.models.YoloV2(),\n",
    "    _foo = 10,                   # Will not be stored when saving\n",
    ")\n",
    "\n",
    "params.bar = ['abc', 'def']\n",
    "params._baz = 3.1415            # Will not be stored when saving\n",
    "\n",
    "# Access and modify parameters\n",
    "params.bar.append('ghi')\n",
    "params.baz *= 2\n",
    "\n",
    "# Show values\n",
    "print(params)  # Note stars behind values that will not be stored\n",
    "\n",
    "# Save values\n",
    "params.save('parameters.state.pt')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can now use the serialized hyperparameters to reload all previous values.\n",
    "This can be useful to resume an interrupted training."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor(1.)\n",
      "tensor(1.)\n",
      "HyperParameters(\n",
      "  bar = ['abc', 'def', 'ghi']\n",
      "  batch = 0\n",
      "  epoch = 0\n",
      "  network = YoloV2\n",
      ")\n"
     ]
    }
   ],
   "source": [
    "# Create parameters object\n",
    "params = ln.engine.HyperParameters(\n",
    "    network = ln.models.YoloV2()\n",
    ")\n",
    "\n",
    "# Show a single weight value from network\n",
    "print(params.network.layers[0][0].layers[1].weight.data[0])\n",
    "\n",
    "# load previous state\n",
    "params.load('parameters.state.pt')\n",
    "\n",
    "# Show same weight value as before\n",
    "print(params.network.layers[0][0].layers[1].weight.data[0])\n",
    "\n",
    "# Show values\n",
    "print(params)  # Note extra bar value"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The HyperParameters class also has a static classmethod [from_file](../api/engine.rst#lightnet.engine.HyperParameters.from_file) which allows to load a HyperParameters object from an external file. This is what allows to easily use the same training pipeline with different parameters, networks, etc. To see how this works, check out the next tutorial where we show a real example of training on the Pascal VOC dataset."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Engine\n",
    "The lightnet [Engine](../api/engine.rst#lightnet.engine.Engine) class provides an opinionated framework to reduce the boilerplate code when building up a training pipeline.  \n",
    "We will again refer to the next tutorial about training on Pascal VOC to see an actual example implementation, but we will quickly go over the basic intended usage and all features this engine offers.\n",
    "\n",
    "The Engine is an abstract base class ([ABC](https://docs.python.org/3/library/abc.html#abc.ABC)), which means that you are intended to create your own Engine class which inherits from this one and implement a few methods.  \n",
    "The 2 methods which you are required to implement, are the [process_batch](../api/engine.rst#lightnet.engine.Engine.process_batch) function, which recieves data from your dataloader and should perform the forward and backward passes through your network (and loss), and the [train_batch](../api/engine.rst#lightnet.engine.Engine.train_batch), which should perform the weight update based on the gradients of your parameters.\n",
    "\n",
    "When initializing your Engine, you need to give it a HyperParameters object, a dataloader _(optional)_ and any other keyword arguments.\n",
    "During execution of your engine, you can access any attribute of your HyperParameters as if it were an engine attribute.\n",
    "All keyword arguments passed to the initialization will also be accessible as attributes.\n",
    "The dataloader will be used to get data for training your network and is what is passed on to the process_batch() function.\n",
    "\n",
    "\n",
    "<div class=\"alert alert-info\">\n",
    "\n",
    "**Note:**\n",
    "\n",
    "The reason the training is split up in 2 functions, is to allow to work with mini-batches.  \n",
    "Mini-batches allow to emulate much bigger batches than fit on you can fit on your computer's (GPU) memory.\n",
    "\n",
    "By setting an *engine.mini_batch_size* attribute on the engine, it will call *process_batch()* multiple times before calling *train_batch()*, which effectively means you will accumulate the gradients of the parameters in your network, which is how mini-batches work.\n",
    "\n",
    "If you do not have such an *engine.mini_batch_size* attribute, it will be set to **1** which boils down to having no mini-batches, because each *process_batch()* call will be followed by a *train_batch()* call.\n",
    "\n",
    "</div>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Implement engine\n",
    "class CustomEngine(ln.engine.Engine):\n",
    "    def start(self):\n",
    "        \"\"\" Do whatever needs to be done before starting \"\"\"\n",
    "        self.params.to(self.device)  # Casting parameters to a certain device\n",
    "        self.optim.zero_grad()       # Make sure to start with no gradients\n",
    "        self.loss_acc = []           # Loss accumulator\n",
    "        \n",
    "    def process_batch(self, data):\n",
    "        \"\"\" Forward and backward pass \"\"\"\n",
    "        data, target = data  # Unpack\n",
    "        \n",
    "        output = self.network(data)\n",
    "        loss = self.loss(output, target)\n",
    "        loss.backward()\n",
    "                \n",
    "        self.loss_acc.append(loss.item())\n",
    "        \n",
    "    def train_batch(self):\n",
    "        \"\"\" Weight update and logging \"\"\"\n",
    "        self.optim.step()\n",
    "        self.optim.zero_grad()\n",
    "        \n",
    "        batch_loss = sum(self.loss_acc) / len(self.loss_acc)\n",
    "        self.loss_acc = []\n",
    "        self.log(f'Loss: {batch_loss}')\n",
    "        \n",
    "    def quit(self):\n",
    "        if self.batch >= self.max_batches:  # Should probably save weights here\n",
    "            print('Reached end of training')\n",
    "            return True\n",
    "        return False\n",
    "        \n",
    "# Create HyperParameters\n",
    "params = ln.engine.HyperParameters(\n",
    "    network=ln.models.YoloV2(),\n",
    "    mini_batch_size=8,\n",
    "    batch_size=64,\n",
    "    max_batches=128\n",
    ")\n",
    "params.loss = ln.network.loss.RegionLoss(params.network.num_classes, params.network.anchors)\n",
    "params.optim = torch.optim.SGD(params.network.parameters(), lr=0.001)\n",
    "\n",
    "# Create engine\n",
    "engine = CustomEngine(\n",
    "    params, None,              # Dataloader (None) is not valid\n",
    "    device=torch.device('cpu')\n",
    ")\n",
    "\n",
    "# Run engine\n",
    "# engine() -> we will not run this here, as we did not provide a valid dataloader"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The Engine class also allows to define hooks.\n",
    "These are functions that run at various points during the training (*batch_start*, *batch_end*, *epoch_start*, *epoch_end*).\n",
    "You can either specify them by decorating a function with the correct decorator, or by calling the decorator at runtime for your engine."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "class CustomEngine(ln.engine.Engine):\n",
    "    def start(self):\n",
    "        # Decide of a hook at runtime (backup_rate gets passed to init function)\n",
    "        if self.backup_rate > 0:\n",
    "            self.batch_end(self.backup_rate)(self.backup)  # Call self.backup every after every N batches\n",
    "            \n",
    "    def backup(self):\n",
    "        self.params.save(f'backup-{self.batch}.state.pt')\n",
    "        \n",
    "    @ln.engine.Engine.epoch_end()\n",
    "    def every_epoch(self):\n",
    "        print('END OF EPOCH')\n",
    "        \n",
    "    @ln.engine.Engine.batch_start(10)\n",
    "    def every_ten_batches(self):\n",
    "        # Runs every ten batches (batch 10, 20, 30, ...)\n",
    "        print(f'STARTING BATCH {self.batch+1}')"
   ]
  }
 ],
 "metadata": {
  "hide_input": false,
  "kernelspec": {
   "display_name": "Lightnet-dev",
   "language": "python",
   "name": "ln"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  },
  "varInspector": {
   "cols": {
    "lenName": 16,
    "lenType": 16,
    "lenVar": 40
   },
   "kernels_config": {
    "python": {
     "delete_cmd_postfix": "",
     "delete_cmd_prefix": "del ",
     "library": "var_list.py",
     "varRefreshCmd": "print(var_dic_list())"
    },
    "r": {
     "delete_cmd_postfix": ") ",
     "delete_cmd_prefix": "rm(",
     "library": "var_list.r",
     "varRefreshCmd": "cat(var_dic_list()) "
    }
   },
   "types_to_exclude": [
    "module",
    "function",
    "builtin_function_or_method",
    "instance",
    "_Feature"
   ],
   "window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
